#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Run thrush-prefiltered AIC signatures for selected benchmarks.

- Loads stacked feature matrix & chunk_ids from:
    <BASE>/output/feature_matrix/parts_filtered/
- Loads performance matrix CSV (MMLU/BBH/MBPP/IFEval) via --perf-csv.
- Filters target columns using --task-prefixes and/or --include/--exclude.
- Handles modulo slicing over targets via --N and --part (done in Python).
- Writes per-benchmark signatures to:
    <sig_dir>/<BENCH>_signatures.csv
"""

import os, re, argparse
from pathlib import Path
from typing import Dict, Optional, List, Iterable

import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.preprocessing import StandardScaler
import pyarrow.parquet as pq

# ──────────────────────────────────────────────────────────────────────────────
# Paths & defaults
# ──────────────────────────────────────────────────────────────────────────────
# Root of the project tree; every other path below is derived from it.
BASE      = Path("XXXXX")
# Directory holding the feature-matrix artifacts (also where chunkid2text.parquet is looked up).
FM_BASE   = BASE / "output" / "feature_matrix"
# Filtered matrix parts and their chunk-id arrays, read by load_feature_matrix_parts().
PFILT     = FM_BASE / "parts_filtered"
PER_MODEL = FM_BASE / "per_model"  # (optional, not needed if chunkid2text exists)

# Default value of the --perf-csv command-line option.
DEFAULT_PERF = BASE / "output" / "benchmark_performance" / "mmlu_performance_matrix.csv"

# Matches "<name>__part<digits>.parquet"; group 1 is the name, group 2 the part number.
# NOTE(review): neither _pat_pm nor PER_MODEL is referenced in the visible file — confirm before removing.
_pat_pm = re.compile(r"^(.*?)__part(\d+)\.parquet$")

# Thrush config
THRUSH_PCT  = 0.01    # passed as thrush_pct: fraction of scored rows taken from each tail
DIRECTION   = "both"  # passed as the stepwise-AIC direction
VERBOSE_AIC = False   # OR-ed with the --verbose flag in main()

# ──────────────────────────────────────────────────────────────────────────────
# Thrush + AIC utilities
# ──────────────────────────────────────────────────────────────────────────────
def get_thrush_rank_correlations(cov_df: pd.DataFrame, resp: dict) -> pd.Series:
    """Score every feature row by a rank-based association with the response.

    Only the models present both as columns of ``cov_df`` and as keys of
    ``resp`` are used, taken in sorted name order. For each row the score is
    ``2 * sum_m rank_row(m) * (2 * rank_resp(m) - (M + 1))`` where ``M`` is the
    number of shared models and ranks use ``method="first"``.

    Args:
        cov_df: One row per feature, one column per model.
        resp: Mapping ``model name -> response value``.

    Returns:
        A Series named ``"thrush_correlation"`` carrying ``cov_df``'s index.

    Raises:
        ValueError: If fewer than two models are shared.
    """
    shared_models = sorted(set(resp.keys()).intersection(cov_df.columns))
    if len(shared_models) < 2:
        raise ValueError("Need ≥2 common models")

    # Response ranks, shifted so that they are symmetric around zero.
    resp_ranks = (
        pd.Series(resp, dtype=float)[shared_models]
        .rank(method="first")
        .to_numpy(dtype=np.float64)
    )
    n_models = resp_ranks.size
    centered_resp = 2.0 * resp_ranks - (n_models + 1.0)

    # After the transpose each column is one feature, ranked across models.
    per_feature_ranks = (
        cov_df[shared_models]
        .T
        .rank(axis=0, method="first")
        .to_numpy(dtype=np.float64)
    )

    scores = 2.0 * (per_feature_ranks.T @ centered_resp)
    return pd.Series(scores, index=cov_df.index, name="thrush_correlation")

def _select_candidates_by_thrush_series(
    thr_series: pd.Series, pct: float = 0.01, min_candidates: int = 50
) -> List[int]:
    """Pick candidate rows from both tails of a score series.

    NaN scores are ignored. ``pct`` is clamped to ``[0, 0.5]`` and at least one
    row is taken from each tail. When the two tails together hold fewer than
    ``min_candidates`` distinct rows, the shortfall is filled from the
    remaining rows, split between the next-highest and next-lowest scores.

    Args:
        thr_series: Scores indexed by row label (labels must be int-convertible).
        pct: Fraction of the non-NaN rows to take from each tail.
        min_candidates: Target minimum number of distinct candidates.

    Returns:
        Distinct row labels as ints: low tail first, then high tail.
    """
    scores = thr_series.dropna().astype(float)
    total = len(scores)
    if not total:
        return []

    capped = min(0.5, pct)
    frac = float(max(0.0, capped))
    tail = max(1, int(np.floor(frac * total)))

    highest = scores.nlargest(tail).index.tolist()
    lowest = scores.nsmallest(tail).index.tolist()
    picked = [*dict.fromkeys(lowest + highest)]

    shortfall = min_candidates - len(picked)
    if shortfall > 0:
        leftover = scores[~scores.index.isin(picked)]
        more_high = leftover.nlargest((shortfall + 1) // 2).index.tolist()
        more_low = leftover.nsmallest(shortfall // 2).index.tolist()
        # dict.fromkeys de-duplicates while keeping first-seen order.
        picked = [*dict.fromkeys(lowest + more_low + highest + more_high)]

    return [int(label) for label in picked]

def _build_Xy(
    cov_df: pd.DataFrame,
    resp: dict,
    allowed_rows: Optional[Iterable[int]] = None,
    drop_nan_rows: bool = True,
):
    """Build the standardized design matrix and centered response for OLS.

    Rows of ``cov_df`` are features and columns are models; the returned
    design matrix is transposed so that models become observations.

    Args:
        cov_df: One row per feature, one column per model.
        resp: Mapping ``model name -> response value``.
        allowed_rows: Optional *positional* row indices of ``cov_df`` to keep.
            Duplicates are removed and the positions are processed in
            ascending order. ``None`` means all rows.
        drop_nan_rows: Kept for backward compatibility. Rows containing NaN
            or +/-inf in any shared model column are always removed, so the
            flag does not change the result.

    Returns:
        ``(Xs, y_center, row_pos_full_kept)`` where

        * ``Xs`` has shape ``(n_models, n_kept_features)``, each column
          standardized with ``StandardScaler``;
        * ``y_center`` is the response minus its mean, ordered by the sorted
          shared model names;
        * ``row_pos_full_kept[j]`` is the position in ``cov_df`` of the
          feature stored in column ``j`` of ``Xs``.

        When no feature survives filtering, ``Xs`` has shape
        ``(n_models, 0)`` and ``row_pos_full_kept`` is empty.

    Raises:
        ValueError: If fewer than two models are shared.
    """
    common = sorted(set(cov_df.columns) & set(resp.keys()))
    if len(common) < 2:
        raise ValueError("Need ≥2 common models")

    # .iloc is positional whatever the index looks like, so the (potentially
    # huge) full frame never needs to be copied by reset_index().
    if allowed_rows is not None:
        row_pos_full_kept = np.array(sorted({int(i) for i in allowed_rows}), dtype=int)
        block_df = cov_df.iloc[row_pos_full_kept]
    else:
        row_pos_full_kept = np.arange(len(cov_df), dtype=int)
        block_df = cov_df

    X_DN = block_df[common].to_numpy(dtype=np.float64)

    # Keep only fully finite rows (drops NaN as well as +/-inf).
    finite_rows = np.isfinite(X_DN).all(axis=1)
    if not finite_rows.all():
        X_DN = X_DN[finite_rows]
        row_pos_full_kept = row_pos_full_kept[finite_rows]

    # A feature that is constant across models cannot be standardized.
    varying_rows = np.nanvar(X_DN, axis=1) > 0
    if not varying_rows.all():
        X_DN = X_DN[varying_rows]
        row_pos_full_kept = row_pos_full_kept[varying_rows]

    y = pd.Series(resp, dtype=float)[common].to_numpy(np.float64)
    y_center = (y - y.mean()).astype(np.float64)

    X = X_DN.T
    if X.shape[1] == 0:
        # StandardScaler refuses 0-feature input; hand back an empty design
        # matrix so callers can test ``Xs.shape[1] == 0`` instead of crashing.
        return np.empty((len(common), 0), dtype=np.float64), y_center, row_pos_full_kept

    scaler = StandardScaler(with_mean=True, with_std=True)
    Xs = scaler.fit_transform(X).astype(np.float64)
    return Xs, y_center, row_pos_full_kept

def _aic_subset(Xs: np.ndarray, y: np.ndarray, cols: List[int]) -> float:
    """Return the AIC of an OLS fit of ``y`` on an intercept plus ``Xs[:, cols]``.

    An empty ``cols`` gives the intercept-only model. Any exception raised
    while fitting is mapped to ``+inf`` so that the subset is never preferred
    by a minimizing search.
    """
    n_obs = Xs.shape[0]
    intercept = np.ones(n_obs, dtype=np.float64)
    if len(cols) == 0:
        design = intercept.reshape(n_obs, 1)
    else:
        design = np.column_stack([intercept, Xs[:, cols]])

    try:
        fitted = sm.OLS(y, design).fit()
        return float(fitted.aic)
    except Exception:
        return np.inf

def _stepwise_aic(
    Xs: np.ndarray,
    y: np.ndarray,
    direction: str = "both",
    tol: float = 1e-6,
    max_features: int | None = 1500,
    min_delta: float = 0.0,
    verbose_mode: str = "brief",
    print_every: int = 200,
) -> List[int]:
    """Greedy stepwise feature selection that minimizes the OLS AIC.

    The search starts from the intercept-only model. Each step evaluates
    every single-column addition (if allowed and below the cap) and every
    single-column removal (if allowed), then applies the one move with the
    lowest AIC. A move qualifies only if it beats the best AIC seen in the
    step by more than ``tol`` and improves on the current AIC by at least
    ``min_delta``. The search stops when no move qualifies or when the
    number of selected columns reaches the cap.

    Args:
        Xs: Design matrix of shape ``(N, D)`` (observations x features).
        y: Response vector of length ``N``.
        direction: ``"forward"``, ``"backward"`` or ``"both"``. Because the
            search starts empty, ``"backward"`` alone never has anything to
            drop and returns ``[]``; any other string also returns ``[]``.
        tol: Minimum strict AIC margin over the step's running best.
        max_features: Upper bound on selected columns; always clamped to
            ``min(N - 1, D)``. ``None`` means "use that clamp".
        min_delta: Minimum AIC improvement over the current model.
        verbose_mode: ``"steps"`` logs individual steps, ``"brief"`` logs a
            final summary, anything else is silent.
        print_every: In ``"steps"`` mode, log steps 1, 1 + print_every, ...
            Values below 1 are treated as 1.

    Returns:
        Sorted column indices of ``Xs`` in the final model.
    """
    N, D = Xs.shape
    hard_cap = max(0, min(N - 1, D))
    if max_features is None:
        max_features = hard_cap
    else:
        max_features = max(0, min(int(max_features), hard_cap))

    # ``step % print_every == 1`` never fires for print_every == 1 and divides
    # by zero for 0; normalise once and test ``(step - 1) % print_every``.
    print_every = max(1, int(print_every))
    log_steps = verbose_mode == "steps"
    allow_add = direction in ("forward", "both")
    allow_drop = direction in ("backward", "both")

    selected: List[int] = []
    remaining = list(range(D))
    cur_aic = _aic_subset(Xs, y, selected)
    step = 0
    n_added = 0
    n_dropped = 0

    while True:
        best_aic, best_move = cur_aic, None
        step += 1

        if allow_add and len(selected) < max_features:
            for c in remaining:
                aic = _aic_subset(Xs, y, selected + [c])
                if (aic + tol) < best_aic and (cur_aic - aic) >= min_delta:
                    best_aic, best_move = aic, ("add", c)

        if allow_drop and selected:
            for c in selected:
                trial = [v for v in selected if v != c]
                aic = _aic_subset(Xs, y, trial)
                if (aic + tol) < best_aic and (cur_aic - aic) >= min_delta:
                    best_aic, best_move = aic, ("drop", c)

        if best_move is None:
            if log_steps:
                print(f"[step {step:02d}] stop (no improvement ≥ {min_delta}).")
            break

        action, col = best_move
        if action == "add":
            selected.append(col)
            remaining.remove(col)
            n_added += 1
        else:
            selected.remove(col)
            remaining.append(col)
            n_dropped += 1

        if log_steps and (step - 1) % print_every == 0:
            print(f"[step {step:02d}] {action:>4} {col:6d} | AIC: {cur_aic:.6f} → {best_aic:.6f} (Δ={cur_aic-best_aic:.6f})")
        cur_aic = best_aic

        if allow_add and len(selected) >= max_features:
            if log_steps:
                print(f"[step {step:02d}] reached max_features={max_features}; stop.")
            break

    if verbose_mode in ("brief", "steps"):
        print(f"[AIC] moves: +{n_added}/-{n_dropped} | selected={len(selected)} | final AIC={cur_aic:.6f}")
    return sorted(selected)

def thrush_prefiltered_aic_for_benchmark(
    token_cov: pd.DataFrame,
    resp: dict,
    bench_name: str,
    thrush_pct: float = 0.01,
    direction: str = "both",
    sampled_df: Optional[pd.DataFrame] = None,
    verbose: bool = False
) -> pd.DataFrame:
    """Select a sparse feature signature for one benchmark.

    Pipeline: score all rows of ``token_cov`` with
    ``get_thrush_rank_correlations``, keep the two tails of that score as
    candidates, run stepwise AIC selection on the candidates, then refit OLS
    on the selected columns to obtain coefficients.

    Args:
        token_cov: One row per feature, one column per model.
        resp: Mapping ``model name -> benchmark score``.
        bench_name: Value written to the ``benchmark`` column.
        thrush_pct: Fraction of rows taken from each tail of the score.
        direction: Stepwise direction passed to ``_stepwise_aic``.
        sampled_df: Optional frame, positionally aligned with ``token_cov``,
            whose ``chunk_text`` column is attached to the output.
        verbose: Print a summary of the AIC search and of the result.

    Returns:
        DataFrame with columns ``benchmark``, ``row_pos_full`` (position of
        the feature in ``token_cov``) and ``coef`` (OLS coefficient on the
        standardized feature against the centered response), sorted by
        ``|coef|`` descending, plus ``chunk_text`` when ``sampled_df``
        provides it. An empty frame (with all four column names) is returned
        when nothing is selected.
    """
    empty_columns = ["benchmark", "row_pos_full", "coef", "chunk_text"]

    # Everything downstream (``_build_Xy``'s .iloc, ``sampled_df.iloc`` and
    # the caller's chunk-id lookup) treats row ids as *positions*. Dropping
    # the index labels here keeps candidates correct even when ``token_cov``
    # does not carry a default RangeIndex.
    thr = get_thrush_rank_correlations(token_cov, resp).reset_index(drop=True)

    n_models = len(set(token_cov.columns) & set(resp.keys()))
    cand_rows = _select_candidates_by_thrush_series(
        thr_series=thr, pct=thrush_pct, min_candidates=max(2 * n_models, 50)
    )
    if len(cand_rows) == 0:
        return pd.DataFrame(columns=empty_columns)

    Xs, y_center, row_pos_full_kept = _build_Xy(token_cov, resp, allowed_rows=cand_rows)
    if Xs.shape[1] == 0:
        return pd.DataFrame(columns=empty_columns)

    sel_cols = _stepwise_aic(
        Xs, y_center,
        direction=direction,
        tol=1e-6,
        max_features=1500,
        min_delta=0.0,
        verbose_mode="brief" if verbose else "none",
        print_every=200,
    )
    if len(sel_cols) == 0:
        return pd.DataFrame(columns=empty_columns)

    # Refit on the chosen columns only; params[0] is the intercept.
    X_design = sm.add_constant(Xs[:, sel_cols], has_constant="add")
    res = sm.OLS(y_center, X_design).fit()
    coefs = np.asarray(res.params, dtype=float)[1:].tolist()

    # Map design-matrix columns back to positions in the full matrix.
    sel_rows = row_pos_full_kept[sel_cols].astype(int)
    out = pd.DataFrame({
        "benchmark": bench_name,
        "row_pos_full": sel_rows,
        "coef": coefs,
    }).sort_values(by="coef", key=np.abs, ascending=False).reset_index(drop=True)

    if sampled_df is not None and "chunk_text" in sampled_df.columns:
        out["chunk_text"] = sampled_df.iloc[out["row_pos_full"]]["chunk_text"].astype(str).values

    if verbose:
        print(f"{bench_name}: thrush candidates={len(cand_rows)}, usable={Xs.shape[1]}, selected={len(out)}")
    return out

# ──────────────────────────────────────────────────────────────────────────────
# IO helpers
# ──────────────────────────────────────────────────────────────────────────────
def load_feature_matrix_parts():
    """Load and stack all filtered feature-matrix parts from ``PFILT``.

    Matrix parts (``part_*_matrix.filtered.parquet``) and chunk-id parts
    (``part_*_chunk_ids.filtered.npy``) are each read in sorted filename
    order and concatenated, so row ``i`` of the matrix is expected to belong
    to ``chunk_ids[i]``.

    Returns:
        ``(X_np, models, chunk_ids)``: the stacked matrix as a float32 array,
        the list of its column names, and the concatenated chunk ids.

    Raises:
        FileNotFoundError: If either kind of part file is missing entirely.
        ValueError: If the parts are inconsistent (different number of matrix
            and chunk-id files, differing column sets, or a row count that
            does not match the number of chunk ids). Continuing in that state
            would silently attach the wrong chunk ids to features.
    """
    matrix_files = sorted(PFILT.glob("part_*_matrix.filtered.parquet"))
    chunk_ids_files = sorted(PFILT.glob("part_*_chunk_ids.filtered.npy"))
    if not matrix_files or not chunk_ids_files:
        raise FileNotFoundError(f"No parts found in {PFILT}")
    if len(matrix_files) != len(chunk_ids_files):
        raise ValueError(
            f"Mismatched parts in {PFILT}: {len(matrix_files)} matrix files "
            f"vs {len(chunk_ids_files)} chunk_ids files"
        )

    dfs = [pq.read_table(f).to_pandas() for f in matrix_files]
    models = list(dfs[0].columns)

    # pd.concat would silently take the union of columns and pad with NaN.
    expected_columns = set(models)
    for path, part_df in zip(matrix_files, dfs):
        if set(part_df.columns) != expected_columns:
            raise ValueError(f"Column set of {path} differs from {matrix_files[0]}")

    X = pd.concat(dfs, ignore_index=True)
    X_np = X.to_numpy(dtype=np.float32)
    chunk_ids = np.concatenate([np.load(f) for f in chunk_ids_files])

    if len(chunk_ids) != X_np.shape[0]:
        raise ValueError(
            f"Feature matrix has {X_np.shape[0]} rows but {len(chunk_ids)} chunk ids were loaded"
        )
    return X_np, models, chunk_ids

def maybe_load_chunkid2text():
    """Return a ``chunk_id -> chunk_text`` dict if the lookup file is usable.

    Reads ``FM_BASE / "chunkid2text.parquet"``. Returns ``None`` when the file
    does not exist or lacks either the ``chunk_id`` or ``chunk_text`` column.
    """
    lookup_file = FM_BASE / "chunkid2text.parquet"
    if not lookup_file.exists():
        return None

    table = pd.read_parquet(lookup_file)
    required = {"chunk_id", "chunk_text"}
    if not required.issubset(table.columns):
        return None

    text_by_id = table.set_index("chunk_id")["chunk_text"]
    return text_by_id.to_dict()

def read_perf_matrix(perf_csv_path: str):
    """Read the performance matrix CSV and index it by model name.

    A column literally named ``model`` is used as the index when present.
    Otherwise the first column (e.g. ``'Unnamed: 0'``) is used, provided it
    holds text rather than numbers.

    Args:
        perf_csv_path: Path to the CSV file.

    Returns:
        DataFrame indexed by model identifier, one column per benchmark.

    Raises:
        ValueError: If there is no ``model`` column and the first column is
            not text-like.
    """
    df = pd.read_csv(perf_csv_path)
    if "model" in df.columns:
        return df.set_index("model")

    first_col = df.columns[0]
    id_values = df[first_col]
    # Text columns are ``object`` on older pandas but a dedicated string dtype
    # when string inference is enabled; ``dtype == object`` alone would reject
    # the latter.
    is_text = pd.api.types.is_object_dtype(id_values) or pd.api.types.is_string_dtype(id_values)
    if not is_text:
        raise ValueError("Couldn't find a 'model' identifier column in performance matrix.")
    return df.set_index(first_col)

# ──────────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────────
def parse_args():
    """Define the command-line interface and parse the process arguments.

    Options are registered from a single ordered table so that ``--help``
    lists them in the same order as they are declared here.
    """
    parser = argparse.ArgumentParser(description="Run thrush-prefiltered AIC signatures.")

    option_specs = [
        ("--perf-csv", dict(
            type=str,
            default=str(DEFAULT_PERF),
            help="Path to performance matrix CSV (mmlu/bbh/mbpp/ifeval).",
        )),
        ("--task-prefixes", dict(
            type=str,
            default="mmlu_",
            help="Comma-separated prefixes (e.g., 'mmlu_' or 'bbh_'). Leave empty to skip.",
        )),
        ("--include", dict(
            type=str,
            default="",
            help="Comma-separated exact column names to include (overrides prefix filter).",
        )),
        ("--exclude", dict(
            type=str,
            default="",
            help="Comma-separated exact column names to exclude.",
        )),
        ("--only-remaining", dict(
            action="store_true",
            help="Run only those without existing signature CSVs.",
        )),
        ("--list-benchmarks", dict(
            action="store_true",
            help="Print the final (post-split) benchmark list and exit.",
        )),
        ("--sig-dir", dict(
            type=str,
            default=str(BASE / "benchmark_signature"),
            help="Directory to write signatures CSVs.",
        )),
        ("--verbose", dict(
            action="store_true",
            help="Verbose AIC logs.",
        )),
        # Modulo slicing of the target list is done in Python via these two.
        ("--N", dict(
            type=int,
            default=1,
            help="Total parts to split the target list.",
        )),
        ("--part", dict(
            type=int,
            default=0,
            help="This part index (0..N-1).",
        )),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()

def signatures_path(sig_dir: Path, bench_name: str) -> Path:
    """Return the CSV path holding the signatures of ``bench_name`` inside ``sig_dir``."""
    filename = "{}_signatures.csv".format(bench_name)
    return sig_dir.joinpath(filename)

def _select_target_benchmarks(args, perf_df: pd.DataFrame, sig_dir: Path) -> List[str]:
    """Resolve which performance-matrix columns this invocation should process.

    Applies, in order: the prefix filter, ``--include`` / ``--exclude``, the
    ``bbh`` aggregate-column exclusion, ``--only-remaining``, sorting, and the
    ``--N`` / ``--part`` modulo split.
    """
    prefixes = [p.strip() for p in args.task_prefixes.split(",") if p.strip()]
    if prefixes:
        prefix_tuple = tuple(prefixes)
        by_prefix = [c for c in perf_df.columns
                     if isinstance(c, str) and c.startswith(prefix_tuple)]
    else:
        by_prefix = list(perf_df.columns)

    include = [s.strip() for s in args.include.split(",") if s.strip()]
    exclude = {s.strip() for s in args.exclude.split(",") if s.strip()}

    if include:
        # An explicit include list replaces the prefix filter; names that are
        # not columns are dropped and --exclude is not applied.
        target = [b for b in include if b in perf_df.columns]
    else:
        target = [b for b in by_prefix if b not in exclude]

    # helpful default: if using bbh_, drop overall 'bbh' column
    if any(p.startswith("bbh_") for p in prefixes):
        target = [b for b in target if b != "bbh"]

    if args.only_remaining:
        target = [b for b in target if not signatures_path(sig_dir, b).exists()]

    # Sort before splitting so every worker sees the same ordering.
    target = sorted(target)

    n_parts = max(1, int(args.N))
    part = int(args.part) % n_parts
    if n_parts > 1:
        target = target[part::n_parts]
    return target


def main():
    """Command-line entry point: compute and save signatures per benchmark.

    The benchmark list is resolved from the performance matrix first; the
    feature matrix is loaded only when there is something to run, so
    ``--list-benchmarks`` and empty target lists return without touching it.
    Benchmarks whose output CSV already exists are skipped.
    """
    args = parse_args()
    SIG_DIR = Path(args.sig_dir)
    SIG_DIR.mkdir(parents=True, exist_ok=True)

    # 1) Performance matrix and target list (cheap)
    perf_df = read_perf_matrix(args.perf_csv)
    target = _select_target_benchmarks(args, perf_df, SIG_DIR)

    if args.list_benchmarks:
        print(",".join(target))
        return

    if not target:
        print("Nothing to run (empty target list).")
        return

    # 2) Feature matrix and optional chunk text (expensive, so loaded late)
    X, models, chunk_ids = load_feature_matrix_parts()
    token_cov = pd.DataFrame(X, columns=models)
    chunkid2text = maybe_load_chunkid2text()

    # 3) Iterate and run
    for BENCH_NAME in target:
        out_path = signatures_path(SIG_DIR, BENCH_NAME)
        # Re-checked at run time: another worker may have written it meanwhile.
        if out_path.exists():
            print(f"{BENCH_NAME}: already exists → {out_path} (skip)")
            continue

        resp_series = perf_df[BENCH_NAME].dropna()

        # Mapping info (printed from Python, not bash)
        available_models = [m for m in token_cov.columns if m in resp_series.index]
        mapped = sorted(available_models)
        print(f"[{BENCH_NAME}] mapped models: {len(mapped)}")
        print(f"[{BENCH_NAME}] mapped list: {', '.join(mapped) if mapped else '(none)'}")
        missing_in_cov = sorted(set(resp_series.index) - set(token_cov.columns))
        if missing_in_cov:
            print(f"[{BENCH_NAME}] models in perf but not in coverage ({len(missing_in_cov)}): {', '.join(missing_in_cov)}")

        if len(available_models) < 2:
            print(f"{BENCH_NAME}: skipped (need ≥2 overlapping models)")
            continue

        print(f"\nRunning signatures for benchmark: {BENCH_NAME}")
        resp: Dict[str, float] = resp_series[available_models].astype(float).to_dict()

        df_sel = thrush_prefiltered_aic_for_benchmark(
            token_cov=token_cov,
            resp=resp,
            bench_name=BENCH_NAME,
            thrush_pct=THRUSH_PCT,
            direction=DIRECTION,
            sampled_df=None,
            verbose=VERBOSE_AIC or args.verbose,
        )

        if df_sel.empty:
            print(f"{BENCH_NAME}: no features selected.")
            continue

        # Translate matrix row positions into chunk ids, then add optional text.
        df_sel["chunk_id"] = [int(chunk_ids[int(i)]) for i in df_sel["row_pos_full"]]
        df_sel = df_sel.drop(columns=["row_pos_full"])
        if chunkid2text is not None:
            df_sel["chunk_text"] = df_sel["chunk_id"].map(lambda cid: chunkid2text.get(int(cid)))

        base_cols = ["benchmark", "chunk_id", "coef"]
        if "chunk_text" in df_sel.columns:
            base_cols.append("chunk_text")
        df_out = df_sel[base_cols]

        df_out.to_csv(out_path, index=False)
        print(f"[Saved] {BENCH_NAME} signatures → {out_path}")

if __name__ == "__main__":
    main()
